import matplotlib.pyplot as plt
import numpy as np
# import torch
import sys
import os
import random


def generate_orthog(N):
    """Return a random ``N x N`` orthogonal matrix.

    A uniform random matrix is shifted by the identity and its orthogonal
    polar factor ``U @ Vh`` (from the thin SVD) is returned.
    """
    perturbed = np.random.rand(N, N) + np.eye(N)
    left, _, right = np.linalg.svd(perturbed, full_matrices=False)
    return np.matmul(left, right)


def generate_adj(orthomat, nums, sparsity_thres=0.05):
    """Build ``nums`` sparse, nonnegative, zero-diagonal adjacency matrices.

    Each matrix is ``orthomat @ diag(|z|) @ orthomat.T`` with ``z`` drawn
    from a standard normal. The diagonal is then zeroed and negative entries
    are set to 0. A cut-off rising in steps of 0.02 clears small entries
    until the fraction of positive entries is at most ``sparsity_thres``.

    Args:
        orthomat: square matrix, normally produced by ``generate_orthog``.
        nums: number of matrices to generate.
        sparsity_thres: maximum allowed fraction of positive entries.

    Returns:
        A list of ``nums`` arrays. Returns ``None`` instead (after printing
        a message) if the cut-off passes 0.99 before the target is reached.
    """
    _, size = orthomat.shape
    n_cells = size * size
    diag_idx = np.arange(size)
    matrices = []
    for _ in range(nums):
        spectrum = np.absolute(np.random.randn(size))
        adj = orthomat @ (spectrum * np.eye(size)) @ (orthomat.T)
        adj[diag_idx, diag_idx] = 0
        cutoff = 0
        adj[adj < cutoff] = 0
        # Raise the cut-off until few enough entries remain positive.
        while (adj > 0).sum() / n_cells > sparsity_thres:
            if cutoff > 0.99:
                print("singular cor mat for sparsity thres = {:.3f}".format(
                    sparsity_thres))
                return None
            cutoff = cutoff + 0.02
            adj[adj < cutoff] = 0
        assert np.diag(adj).sum() == 0
        matrices.append(adj)
    return matrices


def adj_interpolation(adj_list, cycle, sparsity_thres):
    """Expand ``adj_list`` into one row-normalized adjacency per cycle step.

    Despite the name, no blending happens. The ``cycle`` steps are split
    into ``len(adj_list)`` contiguous stages, and every step of a stage
    reuses that stage's matrix (piecewise constant).

    Each stage matrix is copied and its negative entries are set to 0.
    Small entries are then cleared (cut-off rising in steps of 0.02) until
    at most ``sparsity_thres`` of the entries are positive. Finally self
    loops are added and every row is normalized to sum to 1.

    Args:
        adj_list: non-empty list of square, zero-diagonal matrices. They are
            not modified.
        cycle: number of steps in one cycle (length of the returned list).
        sparsity_thres: maximum allowed fraction of positive entries.

    Returns:
        A list of ``cycle`` arrays (distinct objects). Returns ``None``
        instead (after printing a message) if a stage cannot be sparsified
        before the cut-off passes 0.99.
    """
    num_adj = len(adj_list)
    _, N = adj_list[0].shape
    stage_cache = {}
    interadj_list = []
    for idx in range(cycle):
        # Equals idx // (cycle // num_adj) when cycle is a multiple of
        # num_adj, and always stays in [0, num_adj) otherwise. The
        # cycle // num_adj form overran adj_list or divided by zero there.
        stage = idx * num_adj // cycle
        if stage not in stage_cache:
            # Work on a copy so the caller's matrices are left untouched.
            current_adj = adj_list[stage].copy()
            thres = 0
            current_adj[current_adj < thres] = 0
            sparsity = (current_adj > 0).sum()/(N*N)
            while sparsity > sparsity_thres:
                if thres > 0.99:
                    print("singular cor mat for sparsity thres = {:.3f}".format(
                        sparsity_thres))
                    return None
                thres = thres + 0.02
                current_adj[current_adj < thres] = 0
                sparsity = (current_adj > 0).sum()/(N*N)
            # normalization: add self loops, then make each row sum to 1
            assert(np.diag(current_adj).sum() == 0), "diag sum {:.4f}".format(
                np.diag(current_adj).sum())
            current_adj = current_adj + np.eye(N)
            current_adj = current_adj / current_adj.sum(axis=1).reshape(N, 1)
            stage_cache[stage] = current_adj
        # Hand out a fresh array per step, as before, so callers may mutate
        # one step's matrix without affecting the others.
        interadj_list.append(stage_cache[stage].copy())
    return interadj_list


def source_signal(num_series, T):
    """Return ``sin(k * sqrt(T))`` for ``k = 0 .. num_series - 1`` as an array."""
    root_t = T ** 0.5
    return np.array([np.sin(k * root_t) for k in range(num_series)])


def stepwithadj(state_vector, adj, sigma):
    """Take one noisy graph-diffusion step.

    Each node's new value is drawn from a normal distribution centred on
    ``(adj @ state_vector)[i]`` with standard deviation ``sigma``.

    Args:
        state_vector: 1-D array of current node values (length N).
        adj: (N, N) matrix with no NaN entries.
        sigma: noise standard deviation (non-negative).

    Returns:
        A new 1-D float array of length N.
    """
    N = len(state_vector)
    assert np.isnan(adj).sum() < 1
    agg_vector = (adj @ state_vector.reshape(N, 1)).squeeze(1)
    # A single broadcast draw replaces the per-element Python loop. The
    # legacy global RNG fills elements in order, so the sampled values
    # match N successive scalar draws.
    return np.random.normal(agg_vector, sigma)


def synthetic(num_series, cluster, length, cycle, sigma, sparsity):
    """Generate a synthetic multivariate series driven by adjacency matrices.

    ``generate_adj`` builds ``cluster`` sparse adjacency matrices from a
    random orthogonal basis. At every step that is a multiple of ``cycle``
    the state is reset to values drawn from {-1, -0.5, 0.5, 1}. Every other
    step is produced by ``stepwithadj`` using a row-normalized adjacency
    with self loops.

    Args:
        num_series: number of series (nodes).
        cluster: number of distinct adjacency matrices. When > 1 they are
            laid out over each cycle by ``adj_interpolation``.
        length: number of time steps to generate.
        cycle: reset period, in steps.
        sigma: noise standard deviation passed to ``stepwithadj``.
        sparsity: maximum fraction of positive adjacency entries.

    Returns:
        ``(mts_data, adjs)``. ``mts_data`` has shape
        ``(num_series, length)``. ``adjs`` holds the row-normalized
        adjacency matrices used: one per cycle step when ``cluster > 1``,
        otherwise a one-element list.

    Raises:
        ValueError: if the adjacency matrices cannot be made as sparse as
            ``sparsity`` requires (the helpers signal this by returning None).
    """
    ortho = generate_orthog(num_series)
    adj_list = generate_adj(ortho, cluster, sparsity)
    if adj_list is None:
        raise ValueError(
            "could not build adjacency matrices with sparsity <= {:.3f}".format(
                sparsity))
    mts_data = np.zeros((num_series, length))
    if cluster > 1:
        adjs = adj_interpolation(adj_list, cycle, sparsity)
        if adjs is None:
            raise ValueError(
                "could not build adjacency matrices with sparsity <= {:.3f}".format(
                    sparsity))
    else:
        adj = adj_list[0]
        assert np.absolute(np.diag(adj)).sum() == 0
        adj = adj + np.eye(num_series)
        adj = adj / adj.sum(axis=1).reshape(num_series, 1)
        adjs = [adj]
    for step in range(length):
        if step % cycle == 0:
            state = np.random.choice([-1, -0.5, 0.5, 1], (num_series))
        else:
            # The transition into `step` uses the adjacency of the previous
            # step within the cycle.
            step_adj = adjs[(step - 1) % cycle] if cluster > 1 else adjs[0]
            state = stepwithadj(state, step_adj, sigma)
        mts_data[:, step] = state
    return mts_data, adjs


def smoothmts(mts_data, smooth_rate):
    """Exponentially smooth each row of a 2-D series along its second axis.

    The first column is copied as is. Every later column is
    ``smooth_rate * previous_smoothed + (1 - smooth_rate) * current``.
    Returns a new float array; the input is not modified.
    """
    _, num_steps = mts_data.shape
    keep = 1 - smooth_rate
    smoothed = np.zeros(mts_data.shape)
    running = mts_data[:, 0]
    smoothed[:, 0] = running
    for t in range(1, num_steps):
        running = smooth_rate * running + keep * mts_data[:, t]
        smoothed[:, t] = running
    return smoothed


if __name__ == "__main__":
    # Seed both RNGs so the generated dataset is reproducible.
    seed = 100
    np.random.seed(seed)
    random.seed(seed)
    root_dir = r"data/synthetic"
    T_length = 2400

    num_series = 100

    cluster = 6

    cycle = 120
    sigma = 0.01
    sparsity = 0.05
    smooth = 0.4
    # A and B are not used below, but drawing them advances the global NumPy
    # RNG. Removing them would change the data generated for this seed.
    A = np.random.randn(num_series)
    B = np.random.randn(num_series)
    mts_data, adj_list = synthetic(
        num_series, cluster, T_length, cycle, sigma, sparsity)
    adj_dict = {str(idx): adj for idx, adj in enumerate(adj_list)}
    # NOTE: "sparisity" is misspelled, but existing datasets/loaders may rely
    # on these exact file names, so the spelling is kept.
    adj_name = "rwadj_cluster{}_cycle{}_numsr{}_sigma{}_sparisity{}_smooth{}_seed{}.npz".format(
        cluster, cycle, num_series, sigma, sparsity, smooth, seed)
    f_name = "conditionrw_cluster{}_cycle{}_numsr{}_sigma{}_sparisity{}_smooth{}_seed{}.npy".format(
        cluster, cycle, num_series, sigma, sparsity, smooth, seed)
    # Smoothing is currently disabled; `smooth` only appears in the file names.
    # mts_data = smoothmts(mts_data, smooth)
    # np.savez / np.save raise FileNotFoundError if the directory is missing.
    os.makedirs(root_dir, exist_ok=True)
    np.savez(os.path.join(root_dir, adj_name), **adj_dict)
    np.save(os.path.join(root_dir, f_name), mts_data)
